from matplotlib import colormaps, colors,cm
from AuxLibraries import *
import auxiliars as aux
import GWPs

class TRANSITIONRISK():
    """Transition-risk analysis: carbon-price table plus sector- and region-level
    earnings-at-risk plots built from reported vs counterfactual emissions."""

    # Value passed as ``latest_methane`` to the AR6 counterfactual (2022-2023),
    # keyed by the GWP horizon in years. Also defines the supported horizons.
    _LATEST_METHANE_GWP = {100: 28.4, 20: 81.1}

    # Gas labels retained for each supported ``which_gas`` option.
    # NOTE(review): nitrous oxide is usually labelled 'N2O'; 'NO2' is kept as in
    # the original selection - confirm it matches the labels in the data.
    _GAS_SELECTION = {'ALL': ['CO2', 'CH4', 'NO2', 'SF6', 'NF3'], 'CH4': ['CH4']}

    def __init__(self):
        pass

    def carbon_price_data(self, dtm):
        """Build a country-by-year table of average ETS compliance carbon prices.

        Reads the 'Compliance_Price' sheet of
        local_data/CarbonPriceWB_latest.xlsx, keeps rows whose
        'Instrument Type' is 'ETS', and averages the 2014-2024 price columns
        per 'Jurisdiction Covered' (cells equal to '-' are treated as missing).

        * The 'EU' jurisdiction is relabelled 'Europe' and joined to dtm's
          'SP_GEOGRAPHY', so its price is assigned to every matching
          'Country_adj' (minus a hard-coded exclusion list). The UK receives
          this price only for years before 2022.
        * Saitama/Tokyo are pooled into 'Japan'. California/RGGI/Washington
          are pooled into the United States.
        * Only a hard-coded list of countries is kept from the non-EU schemes.

        Parameters
        ----------
        dtm : DataFrame
            Must contain 'SP_GEOGRAPHY' and 'Country_adj' columns.

        Returns
        -------
        DataFrame
            Columns 'mrg' ('<Country_adj>-<year>'), 'year' (int),
            'CarbonPrice' and 'Country_adj'.
        """
        years = list(range(2014, 2025))
        # The context manager releases the workbook file handle after parsing.
        with pd.ExcelFile('local_data/CarbonPriceWB_latest.xlsx') as workbook:
            z = workbook.parse('Compliance_Price')
        # The real header is stored in the first data row.
        z.columns = z.iloc[0]
        z = z.iloc[1:]
        z = z[z['Instrument Type'] == 'ETS']

        pieces = []
        for jurisdiction in z['Jurisdiction Covered'].unique():
            tmp = z[z['Jurisdiction Covered'] == jurisdiction]
            region_ = tmp.Region.iloc[0]
            tmp = tmp[years]
            # Average across all ETS rows of this jurisdiction, year by year.
            # After reset_index() the year column is named after the header
            # row's label (0), hence the rename.
            tmp = pd.DataFrame(tmp.replace('-', np.nan).mean(axis=0),
                               columns=['CarbonPrice']).reset_index().rename(columns={0: 'year'})
            tmp['Country'] = jurisdiction
            tmp['Region'] = region_
            pieces.append(tmp)
        res = pd.concat(pieces)
        res['Country'] = res['Country'].replace(GWPs.get_country_name())

        #== Broadcast the EU price to the European countries present in dtm.
        EU_map = (dtm[['SP_GEOGRAPHY', 'Country_adj']].groupby('Country_adj').last().reset_index()
                  .rename(columns={'SP_GEOGRAPHY': 'Country'}))
        excluded = ['Guernsey', 'Indonesia', 'Isle of Man', 'Jersey',
                    'Russian Federation (the)', 'Åland Islands']
        EU_map = EU_map[~EU_map.Country_adj.isin(excluded)]
        I = res.copy()
        I['Country'] = I['Country'].replace('EU', 'Europe')
        I = I.merge(EU_map)
        I = I[['year', 'CarbonPrice', 'Country_adj']].rename(columns={'Country_adj': 'Country'})
        uk_name = 'United Kingdom of Great Britain and Northern Ireland (the)'
        is_uk = I.Country == uk_name
        # The UK keeps the EU price only before 2022; its own scheme comes from res.
        I = pd.concat((I[~is_uk], I[is_uk & (I.year < 2022)]))

        #== Pool sub-national schemes to country level.
        # NOTE(review): 'Japan' is not in kept_countries below, so the pooled
        # Japanese rows are dropped - confirm this is intended.
        res['Country'] = res['Country'].replace(['Saitama', 'Tokyo'], ['Japan', 'Japan'])
        res['Country'] = res['Country'].replace(['California', 'RGGI', 'Washington'],
                                                ['United States of America (the)'] * 3)
        res = res[['Country', 'year', 'CarbonPrice']].groupby(['Country', 'year']).mean().reset_index()
        kept_countries = ['Argentina', 'Australia', 'Canada', 'China', 'Mexico',
                          'Kazakhstan', 'Korea (the Republic of)', 'New Zealand', 'Switzerland',
                          'South Africa', 'Indonesia', uk_name,
                          'United States of America (the)']
        res = res[res.Country.isin(kept_countries)]

        #== Combine national and EU prices; average duplicates per country-year.
        res = pd.concat((res, I))
        res = res.sort_values(by=['Country', 'year']).rename(columns={'Country': 'Country_adj'})
        res['mrg'] = res['Country_adj'] + '-' + res['year'].astype(int).astype(str)
        res = res[['year', 'CarbonPrice', 'mrg']].groupby('mrg').mean().reset_index()
        # rsplit keeps hyphenated country names intact (only the year is split off).
        res['Country_adj'] = res['mrg'].apply(lambda key: key.rsplit('-', 1)[0])
        res['year'] = res['year'].astype(int)
        return res

    #==
    def TransitionRisk(self, dt, Z, panel_label, order, which_gas='CH4', LOGscale=True, HORIZON=100):
        """Compare reported vs counterfactual carbon costs relative to EBITDA and plot them by sector.

        Reported costs use 'Value (tCO2e)' and counterfactual costs use
        'Value (Equalised)' (both from ``GWPs.Counterfactual``), each
        multiplied by the country-year carbon price and divided by EBITDA.

        Parameters
        ----------
        dt : DataFrame
            Emissions data keyed by 'cdp_id'. It is passed to
            ``aux.primary_merge`` and then ``GWPs.Counterfactual``.
        Z :
            Forwarded unchanged to ``GWPs.Counterfactual``.
        panel_label : sequence
            Its first element is drawn as the panel letter.
        order : dict
            Maps GICS level-1 sector name to plot colour. It sets which
            sectors are drawn and in what order.
        which_gas : {'ALL', 'CH4'}
            Gas selection.
        LOGscale : bool
            Currently unused; the y axis is always log-scaled.
        HORIZON : {100, 20}
            GWP horizon forwarded to ``GWPs.Counterfactual``.

        Returns
        -------
        DataFrame
            Firm-year rows with 'Cost', 'CounterfactualCost' and
            'differential_risk', all in percent of EBITDA.

        Raises
        ------
        ValueError
            If ``HORIZON`` or ``which_gas`` is not supported. This is checked
            before any data is loaded.
        """
        if HORIZON not in self._LATEST_METHANE_GWP:
            raise ValueError(f'HORIZON must be one of {sorted(self._LATEST_METHANE_GWP)}, got {HORIZON!r}')
        if which_gas not in self._GAS_SELECTION:
            raise ValueError('Choose the right gas: ALL | CH4')
        lat_methane = self._LATEST_METHANE_GWP[HORIZON]

        cmp_full, cmp = aux.get_info_data()
        dx = aux.primary_merge(dt, cmp)
        maps = dx[['cdp_id', 'gvkey', 'SP_GEOGRAPHY', 'Country_adj', 'GICS_level_1']].groupby('cdp_id').last().reset_index()
        df = dt.merge(maps, on='cdp_id')

        #=== Counterfactual emissions: AR5 for 2014-2021, AR6 for 2022-2023.
        print('#======', lat_methane)
        A = GWPs.Counterfactual(df, Z, list(range(2014, 2022)), 'AR5', horizon=HORIZON)
        B = GWPs.Counterfactual(df, Z, [2022, 2023], 'AR6', latest_methane=lat_methane, horizon=HORIZON)
        X = pd.concat((A, B))
        X['Value (tCO2e)'] = X['Value (tCO2e)'].astype(float)
        if which_gas == 'ALL':
            print('Looking at all gases')
        else:
            print('Looking at methane alone')
        X = X[X.Gas.isin(self._GAS_SELECTION[which_gas])].copy()

        #== Keep one row per gas-firm-year, then sum gases per firm-year.
        X['mrg'] = X['Gas'] + '-' + X['gvkey'].astype(int).astype(str) + '-' + X['ReportingYear'].astype(int).astype(str)
        X = X.groupby('mrg').last().reset_index(drop=True)
        X['mrg'] = X['gvkey'].astype(int).astype(str) + '-' + X['ReportingYear'].astype(int).astype(str)
        S = X[['Value (tCO2e)', 'Value (Equalised)', 'mrg']].groupby('mrg').sum().reset_index()

        #== Attach financials from the fiscal year preceding the reporting year.
        S['ReportingYear'] = S['mrg'].apply(lambda key: float(key.split('-')[1]))
        S['gvkey'] = S['mrg'].apply(lambda key: int(key.split('-')[0]))
        S['fyear'] = S['ReportingYear'] - 1
        S['mrg'] = S['gvkey'].astype(int).astype(str) + '-' + S['fyear'].astype(int).astype(str)
        cmp_full['mrg'] = cmp_full['gvkey'].astype(int).astype(str) + '-' + cmp_full['fyear'].astype(int).astype(str)
        der = S.merge(cmp_full[['mrg', 'ebitda_usd', 'ebit_usd']])
        der = der[der.ebitda_usd > 0].copy()
        # Scale financials by 1e6 to put them in the same units as emissions x price (presumably USD millions -> USD).
        der['ebitda_usd'] *= 1e6
        der['ebit_usd'] *= 1e6
        der = der.merge(maps[['gvkey', 'SP_GEOGRAPHY', 'Country_adj', 'GICS_level_1']], on='gvkey')
        der['GICS_level_1'] = der['GICS_level_1'].replace(['Information Technology', 'Communication Services'], ['ICT', 'ICT'])

        #== Attach the country-year carbon price.
        cp = self.carbon_price_data(der)
        der['mrg'] = der['Country_adj'].astype(str) + '-' + der['fyear'].astype(int).astype(str)
        der = der.merge(cp[['mrg', 'CarbonPrice']])
        print('Number of companies in the transition risks analysis:', len(der.gvkey.unique()))
        print('Number of observations in the transition risks analysis:', len(der))
        X = der.copy()

        #== Emission cost as a percentage of EBITDA, reported and counterfactual.
        X['Cost'] = 100 * ((X['Value (tCO2e)'].astype(float) * X['CarbonPrice']) / X['ebitda_usd'])
        X['CounterfactualCost'] = 100 * ((X['Value (Equalised)'].astype(float) * X['CarbonPrice']) / X['ebitda_usd'])
        #== Difference COST - COUNTERFACTUAL with a minus sign in front.
        X['differential_risk'] = -(X['Cost'] - X['CounterfactualCost'])
        #== Drop firms whose costs exceed EBITDA, and the single largest differential.
        X = X[X['Cost'] < 100]
        X = X[X['differential_risk'] < X['differential_risk'].max()].copy()
        X['GICS_level_1'] = X['GICS_level_1'].replace('Consumer Staple', 'Consumer Staples')

        #== Plot complementary CDFs per sector plus the full sample.
        plt.figure()
        ax = plt.subplot()
        for sector, colour in order.items():
            sns.ecdfplot(data=X[X.GICS_level_1 == sector]['differential_risk'], color=colour,
                         log_scale=False, ax=ax, lw=2.5, complementary=True)
        sns.ecdfplot(data=X['differential_risk'], color='k',
                     log_scale=False, ax=ax, lw=2,
                     ls='-.', label='Full sample',
                     complementary=True)
        ax.set_yscale('log')
        ax.set_xlabel('Earnings at Risk, %')
        ax.set_ylabel('Complementary CDF')
        ax.spines[['right']].set_visible(False)

        # Top axis: mean reported cost per sector, drawn as short vertical ticks.
        axt = ax.twiny()
        for sector, colour in order.items():
            mean_cost = X[X.GICS_level_1 == sector]['Cost'].mean()
            axt.axvline(x=mean_cost, ymin=0.8, ymax=1, lw=3, ls='--', c=colour)
            print('Average reported methane costs:', sector, round(mean_cost, 2))
        axt.set_xlabel('Methane costs over earnings (as reported), %')
        axt.spines[['right']].set_visible(False)
        ax.text(-0.05, 1.1, panel_label[0], transform=ax.transAxes,
                fontsize=24, fontweight='bold', va='top', ha='right')
        if HORIZON == 20:
            ax.set_ylabel('')
        return X

    def TransitionRisk_BYREGION(self, data_type, panel_label, ax2, H, ordered=None, legend=True):
        """Plot complementary CDFs of 'differential_risk' by world region on ``ax2``.

        The function maps countries to ISO-3 codes using
        local_data/country_iso.xlsx, then maps those codes to regions using
        local_data/iso_region.xlsx. Australia and New Zealand, Central Asia and
        Eastern Asia are pooled as 'Asia-Pacific'. Only five regions are kept.

        Parameters
        ----------
        data_type : DataFrame
            Needs 'Country_adj', 'differential_risk' and 'Cost' columns.
            Presumably this is the frame returned by ``TransitionRisk``;
            confirm with the caller. The caller's frame is not modified.
        panel_label : sequence
            Its first element is drawn as the panel letter.
        ax2 : matplotlib Axes
            All artists, labels and the legend are drawn on this axes.
        H : int
            GWP horizon. 100 selects the GWP100 annotation; any other value
            selects the GWP20 one.
        ordered : DataFrame or Series, optional
            Indexed by region, giving the plot order. When None or empty,
            regions are sorted by ascending 99th percentile of
            'differential_risk'.
        legend : bool
            Whether to draw the region legend.

        Returns
        -------
        (DataFrame, ordered) when ``H == 100``, otherwise the DataFrame alone.
        """
        isor = pd.read_excel('local_data/iso_region.xlsx')
        isor['Region'] = isor['Region'].apply(lambda label: label.split('+ \xa0')[1])
        ciso = pd.read_excel('local_data/country_iso.xlsx')[['Country', 'ISO-3']].rename(columns={'ISO-3': 'loc'})
        ciso = isor.merge(ciso)

        # Copy so the caller's DataFrame is left untouched.
        data = data_type.copy()
        data['differential_risk'] = data['differential_risk'].astype(float)
        data['Country_adj'] = data['Country_adj'].replace(GWPs.get_country_name())
        data = data.merge(ciso[['Country', 'loc']].rename(columns={'Country': 'Country_adj'}))
        cmap = colormaps['coolwarm']
        cols = [colors.rgb2hex(cmap(v)) for v in np.linspace(0, 1, 5)]
        # isor is not modified by the merge above, so it is reused instead of re-reading the file.
        data = data.merge(isor)
        data['Region'] = data['Region'].replace(['Australia and New Zealand', 'Central Asia', 'Eastern Asia'], ['Asia-Pacific'] * 3)
        data = data[data.Region.isin(['Northern America', 'Northern Europe', 'Southern Europe', 'Western Europe', 'Asia-Pacific'])]

        if ordered is None or len(ordered) == 0:
            ordered = (data[['Region', 'differential_risk']].groupby('Region').quantile(0.99)
                       .sort_values(by='differential_risk', ascending=True))

        sort_type = ordered.index
        for colour, region in zip(cols, sort_type):
            sns.ecdfplot(data=data[data.Region == region]['differential_risk'], color=colour,
                         ax=ax2, lw=2, log_scale=False, complementary=True)
        ax2.set_yscale('log')
        # Use ax2's own methods: pyplot helpers act on the *current* axes, which
        # may be a twin created by an earlier call rather than ax2.
        if legend:
            ax2.legend(labels=list(sort_type), loc='center left', bbox_to_anchor=(1, 0.5))
        ax2.set_xlabel('Earnings at Risk, %')
        ax2.set_ylabel('Complementary CDF')
        ax2.spines[['right']].set_visible(False)
        props = dict(boxstyle='round', facecolor='#7ac2e0', alpha=0.5)
        if H == 100:
            ax2.text(0.05, 0.1, 'Counterfactual = GWP$_{100}$', bbox=props)
        else:
            ax2.text(5, 0.1, 'Counterfactual = GWP$_{20}$', bbox=props)

        axt = ax2.twiny()
        for colour, region in zip(cols, sort_type):
            # NOTE(review): if data_type comes from TransitionRisk, 'Cost' is
            # already a percentage, so this extra factor of 100 would inflate the
            # printed averages and the top-axis ticks - confirm the input units.
            mean_cost = 100 * data[data.Region == region]['Cost'].mean()
            axt.axvline(x=mean_cost, ymin=0.8, ymax=1, lw=3, ls='--', c=colour)
            print('Average reported methane costs:', region, round(mean_cost, 2))
        axt.set_xlabel('Methane costs over earnings (as reported), %')
        ax2.text(-0.05, 1.1, panel_label[0], transform=ax2.transAxes,
                 fontsize=24, fontweight='bold', va='top', ha='right')
        if H == 100:
            return data, ordered
        return data